//
// Created by song on 17-1-16.
//

#ifndef C0COMPILER_OPTIMIZEDMIPSGEN_H
#define C0COMPILER_OPTIMIZEDMIPSGEN_H


#include <set>
#include "MIPsIns.h"
#include "IR.h"
#include "Optimizer.h"

/**
 * IR visitor that lowers one IR instruction at a time into MIPS assembly,
 * written through a MIPsIns emitter.
 *
 * Operands are loaded into $t0-$t3 (or $a0 for syscalls). Results are written back to
 * the variable's home location: a data-segment label for variables in `globalVars`,
 * otherwise a $sp-relative slot whose byte offset is read from TmpVar::value.
 * Those offsets (and the frame size stored in the function label's value) are expected
 * to have been computed before the visitor runs (see OptimizedMipsGen::transform()).
 */
class OptimizedIRVisitor: public IRVisitor {

public:
    /**
     * @param globalVars variables that live in the data segment and are addressed by label.
     * @param ins        assembly emitter all generated instructions are written to.
     * @param optimizer  optimizer instance; stored but not read by this class.
     */
    OptimizedIRVisitor(set<TmpVar *> globalVars, MIPsIns &ins, Optimizer &optimizer) :
            opt(optimizer),
            a(ins),
            paramCount(1),
            globalVars(globalVars)
    {
    }


private:
    // Label of the function currently being emitted; its value is the frame size used
    // by the prologue (functionBegin) and the epilogue (functionReturn).
    TmpVar* functionLabel = nullptr;
    // Operands of the IR currently being visited, cached by visit().
    TmpVar* left = nullptr;
    TmpVar* right = nullptr;
    TmpVar* result = nullptr;
    Optimizer & opt;
    MIPsIns &a;
    // 1-based counter of arguments pushed for the next call; reset after each call.
    int paramCount;
    set<TmpVar*> globalVars;
protected:
    ReturnedObj *add(IR *ir) override {
        load(left, T2);
        load(right, T3);
        a.add(T1, T2, T3);
        saveStackVar(result, T1);
        return nullptr;
    }

    ReturnedObj *sub(IR *ir) override {
        load(left, T2);
        load(right, T3);
        a.sub(T1, T2, T3);
        saveStackVar(result, T1);
        return nullptr;
    }

    ReturnedObj *multiple(IR *ir) override {
        load(left, T2);
        load(right, T3);
        a.mul(T1, T2, T3);
        saveStackVar(result, T1);
        return nullptr;
    }

    ReturnedObj *division(IR *ir) override {
        load(left, T2);
        load(right, T3);
        a.div(T1, T2, T3);
        saveStackVar(result, T1);
        return nullptr;
    }

    ReturnedObj *jmp(IR *ir) override {
        a.j(result->name());
        return nullptr;
    }

    ReturnedObj *jge(IR *ir) override {
        load(left, T1);
        load(right, T2);
        a.bge(T1, T2, result->name());
        return nullptr;
    }

    ReturnedObj *jg(IR *ir) override {
        load(left, T1);
        load(right, T2);
        a.bgt(T1, T2, result->name());
        return nullptr;
    }

    ReturnedObj *jne(IR *ir) override {
        load(left, T1);
        load(right, T2);
        a.bne(T1, T2, result->name());
        return nullptr;
    }

    ReturnedObj *je(IR *ir) override {
        load(left, T1);
        load(right, T2);
        a.beq(T1, T2, result->name());
        return nullptr;
    }

    ReturnedObj *jle(IR *ir) override {
        load(left, T1);
        load(right, T2);
        a.ble(T1, T2, result->name());
        return nullptr;
    }

    ReturnedObj *jl(IR *ir) override {
        load(left, T1);
        load(right, T2);
        // left < right  <=>  right > left (the original emitted bge, i.e. "jump if >=").
        a.bgt(T2, T1, result->name());
        return nullptr;
    }

    ReturnedObj *assign(IR *ir) override {
        load(left, T0);
        store(result, T0);
        return nullptr;
    }

    ReturnedObj *readInt(IR *ir) override {
        a.li(V0, "READ_INT");
        a.syscall();
        store(result, V0);   // the syscall leaves the value read in $v0
        return nullptr;
    }

    ReturnedObj *readChar(IR *ir) override {
        a.li(V0, "READ_CHAR");
        a.syscall();
        store(result, V0);   // the syscall leaves the value read in $v0
        return nullptr;
    }

    ReturnedObj *writeStr(IR *ir) override {
        a.la(A0, left->name());
        a.li(V0, "PRINT_STR");
        a.syscall();
        return nullptr;
    }

    ReturnedObj *writeInt(IR *ir) override {
        load(left, A0);
        a.li(V0, "PRINT_INT");
        a.syscall();
        return nullptr;
    }

    ReturnedObj *writeChar(IR *ir) override {
        load(left, A0);
        a.li(V0, "PRINT_CHAR");
        a.syscall();
        return nullptr;
    }

    ReturnedObj *writeln(IR *ir) override {
        a.li(A0, static_cast<int>('\n'));
        a.comment("print new line");
        a.li(V0, "PRINT_CHAR");
        a.syscall();
        return nullptr;
    }

    // The following IR kinds intentionally emit no code from this visitor.
    // NOTE(review): global arrays appear to be emitted into the data segment and local
    // array space reserved during frame layout (both in OptimizedMipsGen::transform) — confirm.
    ReturnedObj *globalArray(IR *ir) override {
        return nullptr;
    }

    ReturnedObj *stackArray(IR *ir) override {
        return nullptr;
    }

    ReturnedObj *noop(IR *ir) override {
        return nullptr;
    }

    // Stores the argument just below the caller's $sp: argument k (1-based) goes to
    // -(4 * (k + 1))($sp), leaving -4($sp) for the callee's saved $ra.
    ReturnedObj *pushParam(IR *ir) override {
        load(left, T0);
        a.sw(T0, SP, (-4) * (paramCount + 1));
        paramCount++;
        return nullptr;
    }

    ReturnedObj *popParam(IR *ir) override {
        return nullptr;
    }

    // Prints the "_err" prefix and the error code held in `left`, then issues syscall 17
    // with that code in $a0 (presumably MARS `exit2` — confirm against the target simulator).
    ReturnedObj *error(IR *ir) override {
        a.la(A0, "_err");
        a.li(V0, "PRINT_STR");
        a.syscall();
        load(left, A0);
        a.li(V0, "PRINT_INT");
        a.syscall();
        load(left, A0);
        a.li(V0, 17);
        a.syscall();
        return nullptr;
    }

    ReturnedObj *call(IR *ir) override {
        a.jal(left->name());
        if(result != nullptr){
            saveStackVar(result, V0);   // the callee returns its value in $v0
        }
        paramCount = 1;                 // the next call starts a fresh argument list
        return nullptr;
    }

    // Epilogue: optionally loads the return value into $v0, pops the frame, restores $ra, returns.
    ReturnedObj *functionReturn(IR *ir) override {
        if(left != nullptr){
            load(left, V0);
        }
        a.addiu(SP, SP, functionLabel->value);
        a.lw(RA, SP, -4);
        a.jr(RA);
        return nullptr;
    }

    // result = left[right]  (word-sized elements)
    ReturnedObj *deref(IR *ir) override {
        load(right, T1);
        a.sll(T1, T1, 2); // index * 4 = byte offset of the element
        if(globalVars.count(left) > 0){
            a.lw(T0, left->name(), T1);     // lw $t0, label($t1)
        }else{
            // element address = $sp + array base offset + index*4.
            // The index in $t1 must be kept, so add $sp to it and use the base as displacement.
            a.add(T1, T1, SP);
            a.lw(T0, T1, left->value);
        }
        saveStackVar(result, T0);
        return nullptr;
    }

    // left[right] = result  (word-sized elements)
    ReturnedObj *assignArray(IR *ir) override {
        load(result, T0);
        load(right, T1);
        a.sll(T1, T1, 2); // index * 4 = byte offset of the element
        if(globalVars.count(left) > 0){
            a.sw(T0, left->name(), T1);     // sw $t0, label($t1)
        }else{
            // element address = $sp + array base offset + index*4 (index preserved in $t1).
            a.add(T1, T1, SP);
            a.sw(T0, T1, left->value);
        }
        return nullptr;
    }

    // Prologue: saves $ra just below the caller's $sp, then allocates the frame.
    ReturnedObj *functionBegin(IR *ir) override {
        functionLabel = ir->label;
        a.sw(RA, SP, -4);
        a.subiu(SP, SP, functionLabel->value);
        return nullptr;
    }

    ReturnedObj *functionEnd(IR *ir) override {
        return nullptr;
    }

    ReturnedObj *sge(IR *ir) override {
        load(left, T2);
        load(right, T3);
        a.sge(T1, T2, T3);
        saveStackVar(result, T1);
        return nullptr;
    }

    ReturnedObj *sg(IR *ir) override {
        load(left, T2);
        load(right, T3);
        a.sgt(T1, T2, T3);
        saveStackVar(result, T1);
        return nullptr;
    }

    ReturnedObj *sne(IR *ir) override {
        load(left, T2);
        load(right, T3);
        a.sne(T1, T2, T3);
        saveStackVar(result, T1);
        return nullptr;
    }

    ReturnedObj *se(IR *ir) override {
        load(left, T2);
        load(right, T3);
        a.seq(T1, T2, T3);
        saveStackVar(result, T1);
        return nullptr;
    }

    ReturnedObj *sle(IR *ir) override {
        load(left, T2);
        load(right, T3);
        a.sle(T1, T2, T3);
        saveStackVar(result, T1);
        return nullptr;
    }

    ReturnedObj *sl(IR *ir) override {
        load(left, T2);
        load(right, T3);
        a.slt(T1, T2, T3);
        saveStackVar(result, T1);
        return nullptr;
    }

    /**
     * Loads variable or constant `v` into register `target`.
     * VAR: a global is loaded by label, a local from value($sp). VALUE: loaded as an immediate.
     * Any other TmpVar type is reported as an internal error.
     */
    void load(TmpVar* v, Register target){
        switch(v->type){
            case TmpVar::VAR :
                if(globalVars.count(v) > 0){ // global variable
                    a.lw(target, v->name());
                }else{
                    a.lw(target, SP, v->value);
                    a.comment(v->name());
                }
                break;
            case TmpVar::VALUE:
                a.li(target, v->value);
                break;
            default:
                Error::nextErrorDetail << "try load a var type is not var or value";
                Error::internal(Error::Should_Not_Happen);
        }
    }

    /**
     * Stores `source` into `v`'s home location: by label if `v` is global,
     * otherwise into its stack slot (see saveStackVar).
     */
    void store(TmpVar* v, Register source){
        if(globalVars.count(v) > 0){ // global var
            a.sw(source, v->name());
        }else{
            saveStackVar(v, source);
        }
    }

    /**
     * Stores `source` into local VAR `v` at value($sp).
     * Reports an internal error if `v` is not a local VAR.
     */
    void saveStackVar(TmpVar* v, Register source){
        if(v->type == TmpVar::VAR && globalVars.count(v) == 0) { // local variable
            a.sw(source, SP, v->value);
            a.comment(v->name());
            return;
        }
        Error::nextErrorDetail << "try save a var that is not a local stack var";
        Error::internal(Error::Should_Not_Happen);
    }

public:

    /** Caches the IR's operands in left/right/result, then dispatches via IRVisitor::visit. */
    ReturnedObj *visit(IR *ir) override {
        left = ir->left;
        right = ir->right;
        result = ir->result;
        return IRVisitor::visit(ir);
    }

};



class OptimizedMipsGen {
    MIPsIns& ins;
    Optimizer& opt;
    vector<IR*> irList;
    set<string> stringSet;
    set<TmpVar*> globalVars;
    set<TmpVar*> localVarSet;
public:

    OptimizedMipsGen(IRList* irList, MIPsIns& ins, Optimizer &opt):ins(ins),opt(opt){
        this->irList = irList->clist;
    }

    void transform() {
        ins.define("PRINT_INT", "1");
        ins.define("PRINT_STR", "4");
        ins.define("PRINT_CHAR","11");
        ins.define("READ_INT",  "5");
        ins.define("READ_CHAR", "12");
        ins.dataSegmentBegin();
        ins.globalString("_err","\"runtime error: \"");

        for(int i=0; i< TmpVar::varList.size(); i++){
            TmpVar* var = TmpVar::varList[i];
            if(var->type==TmpVar::STRING){
                ins.globalString(var->name(), var->content);
            }
        }


        if(Config::irSymbolTable) cout << "#########################################" << endl;
        bool global=true;
        int functionStart = 0;
        int offset = -1; //offset to caller's $sp, 1 means 1 word(4bytes). -1 now points to $ra
        TmpVar* functionLabel;
        vector<TmpVar*> localVarList;
        for(int i=0; i<irList.size(); i++) {
            IR *ir = irList[i];
            if (global) {
                switch (ir->op) {
                    case ASSIGN:
                        ins.globalVariable(ir->result->name(), ir->left->value);
                        globalVars.insert(ir->result); // global variable mark.
                        break;
                    case ARRAY:
                        ins.globalEmptyArray(ir->result->name(), ir->left->value * 4);
                        globalVars.insert(ir->result); // global variable mark.
                        break;
                    case JMP:
                        ins.textSegmentBegin();
                        ins.j(ir->result->name());
                        global = false;
                        functionStart = i + 1;
                        break;
                }
            } else {
                switch(ir->op){
                    case FUNBEGIN:
                        offset = -1;
                        functionLabel = ir->label;
                        break;
                    case FUNEND:
                        functionLabel->value = (-offset)*4;
                        for(int j=0;j<localVarList.size();j++){
                            TmpVar* tmpVar = localVarList[j];
                            tmpVar->value = ((-offset) + tmpVar->value) * 4;
                            if(Config::irSymbolTable) cout << tmpVar->debugStr() << endl;
                        }
                        if(Config::irSymbolTable) cout << "--------------------------------" << endl;
                        localVarList.clear();
                        localVarSet.clear();
                        break;
                    case PUSHARR:
                        offset -= ir->left->value;
                        ir->result->value = offset; // start address of array is lower than the end address
                        localVarList.push_back(ir->result);
                        break;
                    default:
                        if (notAlloc(ir->left)){
                            offset--;
                            ir->left->value = offset;
                            localVarList.push_back(ir->left);
                            localVarSet.insert(ir->left);
                        }
                        if (notAlloc(ir->right)){
                            offset--;
                            ir->right->value = offset;
                            localVarList.push_back(ir->right);
                            localVarSet.insert(ir->right);
                        }
                        if (notAlloc(ir->result)){
                            offset--;
                            ir->result->value = offset;
                            localVarList.push_back(ir->result);
                            localVarSet.insert(ir->result);
                        }
                }
            }
        }


        OptimizedIRVisitor irVisitor(globalVars, ins, opt);
        for(int i=functionStart; i<irList.size(); i++){
            IR* ir = irList[i];
            if(ir->label!=NULL){
                ins.label(ir->label->name());
            }
            irVisitor.visit(ir);
        }
        ins.li(A0, 0);
        ins.li(V0, 17);
        ins.syscall();
    }

    string label(IR* ir){
        stringstream r;
        if(ir->label==NULL){
            return "";
        }else{
            r << ir->label->name() << ":" << endl;
            return r.str();
        }
    }

    string blank(){
        return "\t";
    }

    bool notAlloc(TmpVar *var){
        if(var!=NULL){
            if(var->type==TmpVar::VAR){
                if(globalVars.count(var)==0 && localVarSet.count(var)==0){ // not initialized local var
                    return true;
                }
            }
        }
        return false;
    }
};


#endif //C0COMPILER_OPTIMIZEDMIPSGEN_H
